#!/bin/bash
# Launcher for MGMN TensorRT-LLM runs: MPI rank 0 runs the given command plus
# the MGMN leader node; every other rank runs an MGMN worker node.
# Usage: <this script> <command> [args...]
set -Eeo pipefail

# Command (and its arguments) to run on rank 0.
task_with_command=("$@")
# Open MPI's own rank, if any (not used below; kept for compatibility).
native_mpi_rank=${OMPI_COMM_WORLD_RANK:-}
# Rank from the first launcher that provides one: Slurm, Open MPI, PMI; else 0.
mpi_rank=${SLURM_PROCID:-${OMPI_COMM_WORLD_RANK:-${PMI_RANK:-${PMI_ID:-0}}}}

# Print all arguments, space-joined, to stderr in yellow. The text is printed
# verbatim: unlike `echo -e`, backslash sequences in the message are not
# interpreted.
log_stderr() { printf '\033[33m%s\033[0m\n' "$*" >&2; }
log_stderr "mpi_rank: $mpi_rank"

# PID of this launcher shell. `$$` is used directly instead of querying `ps`,
# whose absence (minimal containers) would abort the script under
# `set -e -o pipefail`.
pid=$$

# Tell TRTLLM to spawn an additional process for the Proxy
export TLLM_SPAWN_PROXY_PROCESS=1

# Print the MPI world size of this job on stdout.
#
# The size comes from the first launcher variable that is set and non-empty,
# in the same precedence order used for mpi_rank: Slurm (SLURM_NTASKS), then
# Open MPI (OMPI_COMM_WORLD_SIZE), then PMI (PMI_SIZE, as set by e.g.
# MPICH/Hydra). Without any launcher information the job is treated as a
# single process and "1" is printed.
function mpi_world_size {
    echo "${SLURM_NTASKS:-${OMPI_COMM_WORLD_SIZE:-${PMI_SIZE:-1}}}"
}

# Ensure TLLM_SPAWN_PROXY_PROCESS_IPC_ADDR is set and exported.
#
# A non-empty user-provided value is kept as is. Otherwise a unique
# `ipc://<tempdir>/rpc_test_<uuid4>` address is generated (despite the
# function's historical name, this is an IPC endpoint, not a TCP address).
# Globals:
#   TLLM_SPAWN_PROXY_PROCESS_IPC_ADDR (read; written and exported if empty)
# Returns:
#   0 on success, 1 if no address could be generated.
function maybe_export_free_tcp_addr_for_spawn_proxy_process {
    # use user specified address if provided
    if [ -n "$TLLM_SPAWN_PROXY_PROCESS_IPC_ADDR" ]; then
        log_stderr "Using user-provided TLLM_SPAWN_PROXY_PROCESS_IPC_ADDR: $TLLM_SPAWN_PROXY_PROCESS_IPC_ADDR"
        return
    fi

    # Generate a unique IPC address without importing tensorrt_llm, to avoid
    # MPI initialization conflicts. The local is declared separately so that a
    # python3 failure is not masked by `local`'s own (always zero) status.
    local ipc_addr
    if ! ipc_addr=$(python3 -c 'import os, tempfile, uuid; print("ipc://" + os.path.join(tempfile.gettempdir(), "rpc_test_" + str(uuid.uuid4())))') \
        || [ -z "$ipc_addr" ]; then
        log_stderr "Failed to generate TLLM_SPAWN_PROXY_PROCESS_IPC_ADDR with python3"
        return 1
    fi
    export TLLM_SPAWN_PROXY_PROCESS_IPC_ADDR=$ipc_addr
    log_stderr "TLLM_SPAWN_PROXY_PROCESS_IPC_ADDR: $TLLM_SPAWN_PROXY_PROCESS_IPC_ADDR"
}


# World size reported by the launcher; exported for child processes.
# Assigned before exporting so a failure is not masked by `export`'s status.
tllm_mpi_size=$(mpi_world_size)
export tllm_mpi_size
log_stderr "tllm_mpi_size: $tllm_mpi_size"

# Random 256-bit (32-byte) hex key exported for the spawned proxy's IPC
# (presumably used to authenticate its messages — confirm in tensorrt_llm).
# Generated before exporting and checked explicitly: `export VAR=$(cmd)`
# ignores cmd's exit status, so a missing or failing `openssl` would
# otherwise silently export an empty key.
if ! TLLM_SPAWN_PROXY_PROCESS_IPC_HMAC_KEY=$(openssl rand -hex 32) \
    || [ -z "$TLLM_SPAWN_PROXY_PROCESS_IPC_HMAC_KEY" ]; then
    log_stderr "Failed to generate TLLM_SPAWN_PROXY_PROCESS_IPC_HMAC_KEY with openssl"
    exit 1
fi
export TLLM_SPAWN_PROXY_PROCESS_IPC_HMAC_KEY

# Rank 0 runs the user task in a background subshell with a scrubbed MPI
# environment, while the MGMN leader node runs in the foreground; every other
# rank runs an MGMN worker node. Each branch exits with a status derived from
# the processes it ran.
if [ -z "$mpi_rank" ] || [ "$mpi_rank" -eq 0 ]; then

    # IPC only works on localhost and in MPI rank0 process
    maybe_export_free_tcp_addr_for_spawn_proxy_process

    log_stderr "Rank${mpi_rank} run ${task_with_command[*]} in background"

    # MPI doesn't allow spawning a process that shares the MPI environment of
    # an MPI process: a duplicate MPI_Init in the child causes undefined
    # behavior. Thus the MPI environment is removed in the child subshell
    # only, while the parent keeps it for its own MPI operations.
    mpi_blacklist=(
        OMPI_ PMIX_ PMI_ SLURM_ MPI_ UCX_
        I_MPI_ HYDRA_ KMP_ MPICH_ MV2_ CRAY_
    )

    (
        # Remove MPI-related variables only in the subshell context.
        # A failed unset (e.g. a readonly variable) is logged instead of
        # aborting the subshell under `set -e`: aborting here would also skip
        # the `--action stop` call below.
        for var in $(compgen -e); do
            for prefix in "${mpi_blacklist[@]}"; do
                if [[ "$var" == "$prefix"* ]]; then
                    unset "$var" || log_stderr "Rank${mpi_rank} could not unset $var"
                    break
                fi
            done
        done

        # Turn off "exit on error" so the following lines always run
        set +e

        # Execute the task with cleaned environment
        "${task_with_command[@]}"
        task_exit_code=$?
        log_stderr "Rank${mpi_rank} Task exit code: $task_exit_code"

        # Stop the MPI Comm server regardless of how the task ended
        python3 -m tensorrt_llm.llmapi.mgmn_leader_node --action stop
        mpi_exit_code=$?
        log_stderr "Rank${mpi_rank} MPI Comm server exit code: $mpi_exit_code"

        # Propagate the task's failure first; otherwise report the stop status
        if [ "$task_exit_code" -ne 0 ]; then
            exit "$task_exit_code"
        else
            exit "$mpi_exit_code"
        fi
    ) &
    # Capture the subshell PID right away, before any other command runs
    subshell_pid=$!

    # Turn off "exit on error" so the following lines always run
    set +e

    log_stderr "Rank${mpi_rank} Subshell PID: $subshell_pid"

    log_stderr "Rank${mpi_rank} run mgmn leader node with mpi_world_size: $(mpi_world_size) ..."
    log_stderr "Rank0 host: $HOSTNAME"
    # NOTE(review): if the leader node exits early, the background task is
    # not signalled, and the `wait` below blocks until the task ends on its own.
    python3 -m tensorrt_llm.llmapi.mgmn_leader_node
    mgmn_leader_node_exit_code=$?
    log_stderr "Rank${mpi_rank} MGMN leader node exit code: $mgmn_leader_node_exit_code"

    # Wait for the subshell; `wait <pid>` returns that subshell's exit code
    wait "$subshell_pid"
    subshell_exit_code=$?
    log_stderr "Rank${mpi_rank} Subshell exit code: $subshell_exit_code"

    # Propagate the subshell's failure first; otherwise the leader's status
    if [ "$subshell_exit_code" -ne 0 ]; then
        exit "$subshell_exit_code"
    else
        exit "$mgmn_leader_node_exit_code"
    fi
else
    # Turn off "exit on error" so the following lines always run
    set +e

    log_stderr "Rank${mpi_rank} run mgmn worker node with mpi_world_size: $(mpi_world_size) ..."
    python3 -m tensorrt_llm.llmapi.mgmn_worker_node
    mgmn_worker_node_exit_code=$?
    log_stderr "Rank${mpi_rank} MGMN worker node exit code: $mgmn_worker_node_exit_code"

    exit "$mgmn_worker_node_exit_code"
fi
